library(readstata13) # package to read Stata 13 dta files
library(tidyverse)
library(lubridate) # Working with dates
library(rstan)
# Runtime configuration ----
# Let rstan write compiled models to disk so they can be reused.
rstan_options(auto_write = TRUE)
# Use every core detected on this machine.
options(mc.cores = parallel::detectCores())
# Fix R's RNG so anything drawn from it in this script is reproducible.
set.seed(1)

# Inputs ----
# Usage: Rscript <script> <path to .rds containing a list with element `vb_fit`>
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 1L) {
  stop("Usage: Rscript <script> <path to VB fit .rds>", call. = FALSE)
}

# Data list passed to Stan (lambda is added further down).
stan_data <- readRDS("Results/stan_data_for_fit_2021_01_18.rds")

# Fitted VB object whose posterior medians seed the optimizer.
vb_fit <- readRDS(args[1])$vb_fit
if (is.null(vb_fit)) {
  stop("No `vb_fit` element found in ", args[1], call. = FALSE)
}

# Output path: every "vb" in the input path becomes "MAP". If the input path
# has no "vb", the substitution is a no-op and saving would overwrite the VB
# fit that was just read, so refuse instead.
filename <- gsub("vb", "MAP", args[1])
if (identical(filename, args[1])) {
  stop(
    "Input path ", args[1], " contains no 'vb'; refusing to overwrite it ",
    "with the MAP fit.",
    call. = FALSE
  )
}

# Posterior medians of the VB fit ----
# One list entry per model parameter (the log-density "lp__" is skipped),
# named by parameter and holding the "50%" column of that parameter's summary
# as an unnamed numeric vector. Serves as both the source of `lambda` and the
# optimizer's initial values below.
vb_param_names <- setdiff(vb_fit@model_pars, "lp__")
median_vb_estimates <- do.call(
  c,
  lapply(vb_param_names, function(par_name) {
    par_summary <- summary(vb_fit, pars = par_name)$summary
    setNames(list(unname(par_summary[, "50%"])), par_name)
  })
)

# Stan model for the MAP fit ----
# Poisson counts with rate Population[i] * (hc[m] + covid[i] * hcp[m]) for
# municipality m = Municipality_code[i]. hc = exp(log_hc) with log_hc <= 0, so
# hc lies in (0, 1]; hcp lies in [0, 1] with an exponential prior whose rate
# `lambda` is supplied as data rather than estimated.
Poisson_model_lasso_MAP <- "
data {
  int<lower=0> N;                 // number of observations (rows)
  int<lower=0> N_municipalities;  // number of municipalities
  int<lower=0> NbDeath[N];        // observed death counts
  int<lower=0> Population[N];     // population multiplying the rate in row i
  // 1-based municipality index; bounded so bad codes fail at data load
  int<lower=1, upper=N_municipalities> Municipality_code[N];
  int<lower=0> covid[N];          // multiplies hcp in the rate for row i
  real<lower=0> lambda;           // rate of the exponential prior on hcp
}
parameters {
  real<upper=0> log_hc[N_municipalities];
  real<lower=0, upper=1> hcp[N_municipalities];
  real mu;
  real<lower=0> sigma;
}
transformed parameters {
  real hc[N_municipalities];
  hc = exp(log_hc);
}
model {
  real lambda_eff[N];

  mu ~ normal(-5, 5);
  sigma ~ normal(1, 20);

  // NOTE(review): log_hc is bounded above by 0, so this prior is a truncated
  // normal whose normalising term normal_lcdf(0 | mu, sigma) depends on mu and
  // sigma but is not included. Left unchanged; confirm this is intended.
  log_hc ~ normal(mu, sigma);
  hcp ~ exponential(lambda);

  for (i in 1:N) {
    lambda_eff[i] = Population[i]
      * (hc[Municipality_code[i]] + covid[i] * hcp[Municipality_code[i]]);
  }

  NbDeath ~ poisson(lambda_eff);
}
"
# cat() cannot create missing directories, so make sure the target exists.
dir.create("Scripts", showWarnings = FALSE, recursive = TRUE)
cat(file = "Scripts/Poisson_model_lasso_MAP.stan", Poisson_model_lasso_MAP)

# Fit ----
# Compile the model and report how long compilation took.
system.time(m <- stan_model(file = "Scripts/Poisson_model_lasso_MAP.stan"))

# lambda is data in this model, fixed at its VB posterior median. Assigning
# NULL to a list element silently deletes it, which would only surface later
# as a missing-data error from Stan, so check first.
if (is.null(median_vb_estimates$lambda)) {
  stop("The VB fit has no `lambda` estimate to fix as data.", call. = FALSE)
}
stan_data$lambda <- median_vb_estimates$lambda

# Optimise from the VB medians (extra init entries such as `lambda` are not
# model parameters here).
MLE_fit <- optimizing(
  object = m,
  data = stan_data,
  verbose = TRUE,
  init = median_vb_estimates,
  iter = 20000
)
# A non-zero return code from the optimizer signals a problem; surface it
# rather than silently saving a possibly unconverged result.
if (isTRUE(MLE_fit$return_code != 0)) {
  warning(
    "optimizing() returned code ", MLE_fit$return_code,
    "; the optimum may not have converged.",
    call. = FALSE
  )
}

print(paste("saving to", filename))

saveRDS(MLE_fit, file = filename)
